// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "subgraph_tests/matmul_multiply_fusion.hpp"

using namespace ov::test;

namespace {
// Shape combinations for which MatMul followed by a Multiply by a constant is expected to be
// fused on CPU (the smoke_MatMulMultiplyFusion suite instantiates them with "can be fused" = true).
//
// Each entry presumably reads {input shape, weights shape, transpose weights, Multiply constant
// shape}; the field order comes from MatMulMultiplyFusionShapeParams in
// subgraph_tests/matmul_multiply_fusion.hpp -- TODO confirm. The boolean is paired with a swapped
// weights layout (e.g. {10, 7}/false vs {7, 10}/true), and an empty constant shape presumably
// denotes a scalar constant.
//
// Observation from the entries: every constant broadcasts to the MatMul output without enlarging
// its shape or rank, and along the output's row (M) axis it has either size 1 or no dimension.
// Declared const because the table is only read through ::testing::ValuesIn.
const std::vector<MatMulMultiplyFusionShapeParams> shape_params = {
    {{2, 2}, {2, 2}, false, {}},
    {{2, 2}, {2, 2}, false, {1}},
    {{2, 2}, {2, 2}, false, {1, 2}},
    {{2, 2}, {2, 2}, true, {1, 2}},
    {{5}, {5}, false, {}},
    {{5}, {5, 1}, false, {}},
    {{5}, {5, 1}, false, {1}},
    {{5}, {5, 3}, false, {3}},
    {{5}, {3, 5}, true, {3}},
    {{5, 10}, {10, 7}, false, {}},
    {{5, 10}, {7, 10}, true, {}},
    {{5, 10}, {10, 7}, false, {7}},
    {{5, 10}, {7, 10}, true, {7}},
    {{5, 10}, {10, 7}, false, {1, 7}},
    {{5, 10}, {7, 10}, true, {1, 7}},
    {{5, 10}, {2, 10, 7}, false, {2, 1, 7}},
    {{5, 10}, {2, 7, 10}, true, {2, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {7}},
    {{5, 10}, {2, 3, 7, 10}, true, {7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 1, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 3, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 3, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {2, 3, 1, 7}},
    {{5, 10}, {10}, false, {}},
    {{5, 10}, {10}, false, {1}},
    {{2, 3, 5, 10}, {10, 7}, false, {1}},
    {{2, 3, 5, 10}, {7, 10}, true, {1}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {1, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {10, 7}, false, {1, 1, 1, 7}},
    {{1, 1, 5, 10}, {7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {2, 3, 1, 7}},
};

// Positive suite: runs every entry of shape_params on the CPU device with the expectation flag
// fixed to true, i.e. each case is expected to end up fused. Test names are generated by
// MatMulMultiplyFusion::getTestCaseName.
INSTANTIATE_TEST_SUITE_P(smoke_MatMulMultiplyFusion,
                         MatMulMultiplyFusion,
                         ::testing::Combine(::testing::ValuesIn(shape_params),
                                            ::testing::Values(true),  // can be fused
                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMulMultiplyFusion::getTestCaseName);

// Shape combinations for which MatMul followed by a Multiply by a constant must NOT be fused on
// CPU (the smoke_NegativeMatMulMultiplyFusion suite instantiates them with "can be fused" = false).
// Entries use the MatMulMultiplyFusionShapeParams layout, presumably {input shape, weights shape,
// transpose weights, Multiply constant shape} -- TODO confirm against the fixture header.
//
// Observation from the entries: each constant either has size > 1 along the output's row (M)
// axis (e.g. {5, 7} for a {5, 10} x {10, 7} MatMul) or would broadcast the MatMul output to a
// larger shape or rank (e.g. {2, 3, 1, 7} against a {1, 1, 5, 7} output, or {1} against the
// scalar produced by {5} x {5}).
// Declared const because the table is only read through ::testing::ValuesIn.
const std::vector<MatMulMultiplyFusionShapeParams> negative_shape_params = {
    {{5}, {5}, false, {1}},
    {{5}, {5}, false, {5}},
    {{5}, {5}, false, {5, 1}},
    {{5}, {5, 3}, false, {1, 3}},
    {{2, 2}, {2, 2}, false, {2, 2}},
    {{2, 2}, {2, 2}, true, {2, 2}},
    {{5, 5}, {5, 5}, false, {5, 5}},
    {{5, 5}, {5, 5}, true, {5, 5}},
    {{5, 10}, {10}, false, {5, 1}},
    {{5, 10}, {10, 7}, false, {5, 7}},
    {{5, 10}, {7, 10}, true, {5, 7}},
    {{5, 10}, {10, 5}, false, {5, 5}},
    {{5, 10}, {5, 10}, true, {5, 5}},
    {{1, 1, 5, 10}, {10, 7}, false, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {7, 10}, true, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {1, 1, 10, 7}, false, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {1, 1, 7, 10}, true, {2, 3, 1, 7}},
    {{2, 1, 5, 10}, {1, 1, 10, 7}, false, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {1, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {10}, false, {5}},
    {{2, 3, 5, 10}, {10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {1, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {1, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 1, 7}},
};

// Negative suite: runs every entry of negative_shape_params on the CPU device with the
// expectation flag fixed to false, i.e. the fusion must not happen for any of these shapes.
// Uses the same MatMulMultiplyFusion fixture and test-name generator as the positive suite.
INSTANTIATE_TEST_SUITE_P(smoke_NegativeMatMulMultiplyFusion,
                         MatMulMultiplyFusion,
                         ::testing::Combine(::testing::ValuesIn(negative_shape_params),
                                            ::testing::Values(false),  // cannot be fused
                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMulMultiplyFusion::getTestCaseName);

// Shape combinations for which the quantized variant of the MatMul + Multiply(constant) pattern
// (QuantizedMatMulMultiplyFusion fixture) is expected to be fused on CPU. Entries use the
// MatMulMultiplyFusionShapeParams layout, presumably {input shape, weights shape, transpose
// weights, Multiply constant shape} -- TODO confirm against the fixture header.
//
// Observation from the entries: besides having size 1 (or no dimension) along the output's row
// (M) axis, every constant here has a rank no greater than the weights' rank. For example, {1, 1}
// is listed for {10, 7} weights, while {1, 1, 1} and {1, 1, 7} for the same MatMul are listed in
// negative_shape_params2.
// Declared const because the table is only read through ::testing::ValuesIn.
const std::vector<MatMulMultiplyFusionShapeParams> shape_params2 = {
    {{2, 2}, {2, 2}, false, {}},
    {{2, 2}, {2, 2}, false, {1}},
    {{2, 2}, {2, 2}, false, {1, 2}},
    {{2, 2}, {2, 2}, true, {1, 2}},
    {{5, 10}, {10, 7}, false, {}},
    {{5, 10}, {7, 10}, true, {}},
    {{5, 10}, {10, 7}, false, {7}},
    {{5, 10}, {7, 10}, true, {7}},
    {{5, 10}, {10, 7}, false, {1, 7}},
    {{5, 10}, {7, 10}, true, {1, 7}},
    {{5, 10}, {2, 10, 7}, false, {2, 1, 7}},
    {{5, 10}, {2, 7, 10}, true, {2, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {7}},
    {{5, 10}, {2, 3, 7, 10}, true, {7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 1, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {1, 3, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {1, 3, 1, 7}},
    {{5, 10}, {2, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{5, 10}, {2, 3, 7, 10}, true, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {1}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {1, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {1, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true, {2, 3, 1, 7}},
};

// Positive quantized suite: runs every entry of shape_params2 through the
// QuantizedMatMulMultiplyFusion fixture on the CPU device with the expectation flag fixed to
// true, i.e. each case is expected to end up fused.
INSTANTIATE_TEST_SUITE_P(smoke_QuantizedMatMulMultiplyFusion,
                         QuantizedMatMulMultiplyFusion,
                         ::testing::Combine(::testing::ValuesIn(shape_params2),
                                            ::testing::Values(true),  // can be fused
                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         QuantizedMatMulMultiplyFusion::getTestCaseName);

// Shape combinations for which the quantized MatMul + Multiply(constant) pattern must NOT be
// fused on CPU (the smoke_NegativeQuantizedMatMulMultiplyFusion suite instantiates them with
// "can be fused" = false). Entries use the MatMulMultiplyFusionShapeParams layout, presumably
// {input shape, weights shape, transpose weights, Multiply constant shape} -- TODO confirm.
//
// Observation from the entries: some constants have size > 1 along the output's row (M) axis or
// would broadcast the output to a larger shape. Others have a rank that exceeds the weights'
// rank, even though the non-quantized shape_params table lists some of those same combinations
// as fusable, e.g. {1, 1, 7} for a {2, 3, 5, 10} x {10, 7} MatMul.
// Declared const because the table is only read through ::testing::ValuesIn.
const std::vector<MatMulMultiplyFusionShapeParams> negative_shape_params2 = {
    {{2, 2}, {2, 2}, false, {2, 2}},
    {{2, 2}, {2, 2}, true, {2, 2}},
    {{5, 5}, {5, 5}, false, {5, 5}},
    {{5, 5}, {5, 5}, true, {5, 5}},
    {{5, 10}, {10, 7}, false, {5, 7}},
    {{5, 10}, {7, 10}, true, {5, 7}},
    {{5, 10}, {10, 5}, false, {5, 5}},
    {{5, 10}, {5, 10}, true, {5, 5}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 1}},
    {{1, 1, 5, 10}, {10, 7}, false, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {7, 10}, true, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 1, 1}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {1, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {10, 7}, false, {2, 3, 1, 7}},
    {{2, 3, 5, 10}, {7, 10}, true, {2, 3, 1, 7}},
    {{1, 1, 5, 10}, {10, 7}, false, {1, 1, 1, 7}},
    {{1, 1, 5, 10}, {7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {1, 1, 1, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {1, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {1, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 10, 7}, false, {2, 3, 5, 7}},
    {{2, 3, 5, 10}, {3, 7, 10}, true, {2, 3, 5, 7}},
};

// Negative quantized suite: runs every entry of negative_shape_params2 through the
// QuantizedMatMulMultiplyFusion fixture on the CPU device with the expectation flag fixed to
// false, i.e. the fusion must not happen for any of these shapes.
INSTANTIATE_TEST_SUITE_P(smoke_NegativeQuantizedMatMulMultiplyFusion,
                         QuantizedMatMulMultiplyFusion,
                         ::testing::Combine(::testing::ValuesIn(negative_shape_params2),
                                            ::testing::Values(false),  // cannot be fused
                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         QuantizedMatMulMultiplyFusion::getTestCaseName);

}  // namespace
